[Notebook] Dask and Xarray on AWS-HPC cluster: distributed processing of Earth data
This notebook continues the previous post by showing the actual code for distributed data processing.
%matplotlib inline
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import dask
from dask.diagnostics import ProgressBar
from dask_jobqueue import SLURMCluster
import distributed
from distributed import Client, progress
dask.__version__, distributed.__version__
%env HDF5_USE_FILE_LOCKING=FALSE
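The same thing can be done from plain Python instead of the `%env` magic; the important part is setting the variable before any HDF5-backed I/O happens. A minimal sketch (the stated motivation, HDF5 file locking misbehaving on shared filesystems such as Lustre, is my assumption about why the notebook sets it):

```python
import os

# Disable HDF5 file locking before any netCDF4/h5py code touches the files;
# locking is known to cause trouble on some shared filesystems
# (assumption: that is why it is disabled here on FSx for Lustre)
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
print(os.environ['HDF5_USE_FILE_LOCKING'])
```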
Data exploration
Data are organized by year/month:
ls /fsx
ls /fsx/2008/
ls /fsx/2008/01/data # one variable per file
# hourly data over a month
dr = xr.open_dataarray('/fsx/2008/01/data/sea_surface_temperature.nc')
dr
# Static plot of the first time slice
fig, ax = plt.subplots(1, 1, figsize=[12, 8], subplot_kw={'projection': ccrs.PlateCarree()})
dr[0].plot(ax=ax, transform=ccrs.PlateCarree(), cbar_kwargs={'shrink': 0.6})
ax.coastlines();
What happens to the values over land? This is easier to check with an interactive plot.
import geoviews as gv
import hvplot.xarray
fig_hv = dr[0].hvplot.quadmesh(
x='lon', y='lat', rasterize=True, cmap='viridis', geo=True,
crs=ccrs.PlateCarree(), projection=ccrs.PlateCarree(), project=True,
width=800, height=400,
) * gv.feature.coastline
# fig_hv
# This is just a hack to display the figure in a Nikola blog post.
# If you know an easier way, let me know.
import holoviews as hv
from bokeh.resources import CDN, INLINE
from bokeh.embed import file_html
from IPython.display import HTML
HTML(file_html(hv.render(fig_hv), CDN))
So it turns out that the "temperature" over land is set to a fill value of 273.16 K (about 0 °C). A better approach is probably to mask those points out.
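One way to do the masking, sketched on a tiny synthetic slice (the 273.16 K fill value is the one observed above; a real workflow would prefer the dataset's own land mask or `_FillValue` attribute if available):

```python
import numpy as np
import xarray as xr

# Tiny stand-in for one SST time slice; the "land" points carry
# the 273.16 K fill value seen in the interactive plot
sst = xr.DataArray(
    np.array([[300.0, 273.16],
              [273.16, 295.5]]),
    dims=['lat', 'lon'],
)

# where() keeps ocean values and turns land points into NaN,
# so plots and reductions ignore them
sst_masked = sst.where(sst != 273.16)

print(float(sst_masked.mean()))  # mean over ocean points only: 297.75
```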
Serial read with master node
Let's start with one year of data.
# Just querying the metadata causes the files to be pulled from S3 to FSx.
# This takes a while on first execution; much faster the second time.
%time ds_1yr = xr.open_mfdataset('/fsx/2008/*/data/sea_surface_temperature.nc', chunks={'time0': 50})
dr_1yr = ds_1yr['sea_surface_temperature']
dr_1yr
The aggregated size is ~29 GB:
dr_1yr.nbytes / 1e9 # GB
with ProgressBar():
mean_1yr_ser = dr_1yr.mean().compute()
mean_1yr_ser
Parallel read with dask cluster
Cluster initialization
!sinfo # spin-up 8 idle nodes
!mkdir -p ./dask_tempdir
# Reference: https://jobqueue.dask.org/en/latest/configuration.html
# - "cores" is the number of CPUs used per Slurm job.
#   Fix it at 72, the number of vCPUs on a c5n.18xlarge node, so one Slurm job gets exactly one node.
# - "processes" is the number of dask workers in a single Slurm job.
# - "memory" is the memory requested in a single Slurm job.
cluster = SLURMCluster(cores=72, processes=36, memory='150GB',
local_directory='./dask_tempdir')
# 8 node * 36 workers/node
cluster.scale(8*36)
cluster
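The arithmetic behind those numbers, spelled out (the values are the ones used above; `cluster.job_script()` shows the exact batch script dask-jobqueue generates from them):

```python
# Worker layout implied by the SLURMCluster settings above
cores, processes, nodes = 72, 36, 8

threads_per_worker = cores // processes   # 72 / 36 = 2 threads per dask worker
total_workers = nodes * processes         # the 8 * 36 = 288 passed to scale()

print(threads_per_worker, total_workers)
```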
Visit http://localhost:8787 for the dashboard.
# remember to also create a dask client that talks to the cluster!
client = Client(cluster) # automatically switches to distributed mode
client
# now the default scheduler is dask.distributed
dask.config.get('scheduler')
!sinfo # now fully allocated
!squeue # all are dask worker jobs, one per compute node
Read 1-year data
# Strictly there is no need to reopen the files -- the previous dask graph
# could simply be reused and would now run on the cluster.
ds_1yr = xr.open_mfdataset('/fsx/2008/*/data/sea_surface_temperature.nc', chunks={'time0': 25})
dr_1yr = ds_1yr['sea_surface_temperature']
dr_1yr
%time mean_1yr_par = dr_1yr.mean().compute()
The effective read bandwidth is about 29 GB / 5 s ≈ 6 GB/s.
mean_1yr_par.equals(mean_1yr_ser) # consistent with serial result
len(dr_1yr.chunks[0]) # there are actually not that many chunks for dask, but it is still super fast
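How the chunk count comes about can be checked on a synthetic array: chunking 720 hourly steps (one 30-day month) in blocks of 50 gives 14 full chunks plus a 20-step remainder. This is a sketch with made-up sizes, not the real files:

```python
import numpy as np
import xarray as xr

# Synthetic month: 720 hourly steps on a small lat/lon grid,
# chunked along time the same way as above
da = xr.DataArray(np.zeros((720, 4, 8)), dims=['time0', 'lat', 'lon'])
da = da.chunk({'time0': 50})

# 720 / 50 -> 14 chunks of 50 plus one chunk of 20
print(len(da.chunks[0]), sum(da.chunks[0]))
```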
Read multi-year data
For this part you might hit a "Too many open files" error. If so, run sudo sh -c "ulimit -n 65535 && exec su $LOGNAME" to raise the limit before starting Jupyter (ref: https://stackoverflow.com/a/17483998).
file_list = [f'/fsx/{year}/{month:02d}/data/sea_surface_temperature.nc'
for year in range(2008, 2018) for month in range(1, 13)]
len(file_list) # many files
file_list[0:3], file_list[-3:]
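With this many files, a single missing month makes `open_mfdataset` fail partway through, so it can be worth checking the list up front. A small stdlib-only sketch (on a machine without `/fsx`, every file simply shows up as missing):

```python
import os

# Same 10-year file list as above: one file per year/month
file_list = [f'/fsx/{year}/{month:02d}/data/sea_surface_temperature.nc'
             for year in range(2008, 2018) for month in range(1, 13)]
assert len(file_list) == 10 * 12  # 120 files

# Report any paths not yet present on the filesystem
missing = [f for f in file_list if not os.path.exists(f)]
print(f'{len(missing)} of {len(file_list)} files missing')
```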
# Pulls the data from S3 to FSx.
# Takes a long time on first execution; much faster the second time.
%time ds_10yr = xr.open_mfdataset(file_list, chunks={'time0': 50})
dr_10yr = ds_10yr['sea_surface_temperature']
dr_10yr
Nearly 300 GB!
dr_10yr.nbytes / 1e9 # GB
%time mean_10yr = dr_10yr.mean().compute()
That's roughly 260 GB / 24 s ≈ 10 GB/s ?!
%time ts_10yr = dr_10yr.mean(dim=['lat', 'lon']).compute()
ts_10yr.plot(size=6)
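When finished, closing the client and then the cluster releases the Slurm allocation (closing a `SLURMCluster` cancels its dask worker jobs). The same shutdown pattern, demonstrated with a `LocalCluster` stand-in so the snippet runs anywhere:

```python
from dask.distributed import Client, LocalCluster

# Stand-in for the SLURMCluster above; the shutdown calls are identical
cluster = LocalCluster(n_workers=2, threads_per_worker=1,
                       dashboard_address=':0')  # random free dashboard port
client = Client(cluster)

# Close the client first, then the cluster; with SLURMCluster this
# cancels the dask worker jobs and frees the compute nodes
client.close()
cluster.close()
print(client.status)
```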